from dataclasses import dataclass
from typing import Optional

from torchtyping import TensorType


@dataclass
class PPORLElement:
    """
    RLElement for PPO: the data recorded for a single (query, response) rollout.

    Shapes quoted below are taken from the field annotations.

    :param query_tensor: The query tensor i.e. the prompt tokens. Should be a long tensor.
        Annotated shape: ``(query_size,)``.
    :type query_tensor: torch.Tensor

    :param response_tensor: The response tensor i.e. the output tokens. Should be a long tensor.
        Annotated shape: ``(response_size,)``.
    :type response_tensor: torch.Tensor

    :param logprobs: The log probabilities over all tokens in the vocabulary for each token
        generated from the policy network (i.e. the autoregressive model). Should be a float
        tensor of same size as tokens, with a dimension across the vocabulary.
        Annotated shape: ``(response_size, vocab_size)``.
    :type logprobs: torch.Tensor

    :param values: The values for each token generated from the value network or value head.
        Should be a float tensor of same size as tokens. Annotated shape: ``(response_size,)``.
    :type values: torch.Tensor

    :param rewards: The rewards for each token outputted in response. Should be a float tensor
        of same size as tokens. Annotated shape: ``(response_size,)``.
    :type rewards: torch.Tensor

    :param score_train: A scalar (zero-dimensional) tensor. NOTE(review): presumably the
        sequence-level score for this sample used during training -- confirm against the
        code that builds these elements.
    :type score_train: torch.Tensor

    :param input_ids_mixin: Optional tensor of annotated shape ``(alltoken_size,)``; defaults
        to ``None``. NOTE(review): presumably token ids of an auxiliary ("mixin") sequence --
        confirm against the caller.
    :type input_ids_mixin: Optional[torch.Tensor]

    :param attention_mask_mixin: Optional tensor of annotated shape ``(alltoken_size,)``;
        defaults to ``None``. NOTE(review): presumably the attention mask matching
        ``input_ids_mixin`` -- confirm against the caller.
    :type attention_mask_mixin: Optional[torch.Tensor]

    :param token_type_ids_mixin: Optional tensor of annotated shape ``(alltoken_size,)``;
        defaults to ``None``. NOTE(review): presumably the token type ids matching
        ``input_ids_mixin`` -- confirm against the caller.
    :type token_type_ids_mixin: Optional[torch.Tensor]
    """

    # Required fields. Their order defines the positional signature of the
    # generated __init__, so it must not be changed.
    query_tensor: TensorType["query_size"]
    response_tensor: TensorType["response_size"]
    logprobs: TensorType["response_size", "vocab_size"]
    values: TensorType["response_size"]
    rewards: TensorType["response_size"]
    score_train: TensorType[()]

    # Optional fields: they default to None, so the annotation is Optional[...]
    # to keep the declared type consistent with the default value.
    input_ids_mixin: Optional[TensorType["alltoken_size"]] = None
    attention_mask_mixin: Optional[TensorType["alltoken_size"]] = None
    token_type_ids_mixin: Optional[TensorType["alltoken_size"]] = None

@dataclass
class PPORLBatch:
    """
    A batched version of the PPORLElement. See PPORLElement for more details on individual fields.

    Every field carries a leading ``batch_size`` dimension; shapes quoted below are taken
    from the field annotations.

    :param query_tensors: A batch of query tensors. Should be a long tensor.
        Annotated shape: ``(batch_size, query_size)``.
    :type query_tensors: torch.Tensor

    :param response_tensors: A batch of response tensors. Should be a long tensor.
        Annotated shape: ``(batch_size, response_size)``.
    :type response_tensors: torch.Tensor

    :param logprobs: A batch of log probabilities from policy.
        Annotated shape: ``(batch_size, response_size, vocab_size)``.
    :type logprobs: torch.Tensor

    :param values: A batch of values from value network.
        Annotated shape: ``(batch_size, response_size)``.
    :type values: torch.Tensor

    :param rewards: A batch of rewards.
        Annotated shape: ``(batch_size, response_size)``.
    :type rewards: torch.Tensor

    :param score_train: One value per batch entry; annotated shape ``(batch_size,)``.
        NOTE(review): presumably the per-sample training scores -- confirm against the
        code that collates this batch.
    :type score_train: torch.Tensor

    :param input_ids_mixin: Optional tensor of annotated shape
        ``(batch_size, alltoken_size)``; defaults to ``None``. NOTE(review): presumably
        token ids of an auxiliary ("mixin") batch -- confirm against the caller.
    :type input_ids_mixin: Optional[torch.Tensor]

    :param attention_mask_mixin: Optional tensor of annotated shape
        ``(batch_size, alltoken_size)``; defaults to ``None``. NOTE(review): presumably
        the attention mask matching ``input_ids_mixin`` -- confirm against the caller.
    :type attention_mask_mixin: Optional[torch.Tensor]

    :param token_type_ids_mixin: Optional tensor of annotated shape
        ``(batch_size, alltoken_size)``; defaults to ``None``. NOTE(review): presumably
        the token type ids matching ``input_ids_mixin`` -- confirm against the caller.
    :type token_type_ids_mixin: Optional[torch.Tensor]
    """

    # Required fields. Their order defines the positional signature of the
    # generated __init__, so it must not be changed.
    query_tensors: TensorType["batch_size", "query_size"]
    response_tensors: TensorType["batch_size", "response_size"]
    logprobs: TensorType["batch_size", "response_size", "vocab_size"]
    values: TensorType["batch_size", "response_size"]
    rewards: TensorType["batch_size", "response_size"]
    score_train: TensorType["batch_size"]

    # Optional fields: they default to None, so the annotation is Optional[...]
    # to keep the declared type consistent with the default value.
    input_ids_mixin: Optional[TensorType["batch_size", "alltoken_size"]] = None
    attention_mask_mixin: Optional[TensorType["batch_size", "alltoken_size"]] = None
    token_type_ids_mixin: Optional[TensorType["batch_size", "alltoken_size"]] = None
